package topics

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"strconv"
	"strings"
	"time"

	"github.com/go-kit/kit/log"
	redis "github.com/go-redis/redis/v8"
	_ "github.com/go-sql-driver/mysql"
)

// SearchOpts selects which columns Repository.Search filters on.
// Each field only switches a filter on; the value to compare against is
// supplied through Search's variadic args, in field order (Enabled first,
// then IsDefault).
type SearchOpts struct {
	// Enabled, when true, adds an `enabled = ?` condition to the query.
	Enabled   bool
	// IsDefault, when true, adds an `isDefault = ?` condition to the query.
	// Keyword topics are currently treated as non-default topics, so this
	// field is kept mainly for future extension.
	IsDefault bool
}

// Repository is the persistence API for topics.
type Repository interface {
	// Save persists the whole topic.
	Save(ctx context.Context, topic *Topic) error
	// Delete removes a single topic by id.
	Delete(ctx context.Context, id string) error
	// Get fetches a single topic by id. The SQL implementation in this
	// package returns ErrTopicNotFound when no row matches.
	Get(ctx context.Context, id string) (*Topic, error)
	// Search scans the whole topic table, optionally filtered. opts chooses
	// which columns are filtered on; args supplies the comparison values,
	// one per enabled filter, in the order Enabled, IsDefault.
	// Examples:
	//	Search(ctx, SearchOpts{Enabled: true}, true)
	//	-> select ... from topic where enabled = true;
	//	Search(ctx, SearchOpts{Enabled: true, IsDefault: true}, true, false)
	//	-> select ... from topic where (enabled = true) and (isDefault = false);
	Search(ctx context.Context, opts SearchOpts, args ...interface{}) ([]*Topic, error)
	// GetVersion returns the latest (largest) version number. The SQL
	// implementation in this package returns nil when there are no topics.
	GetVersion(ctx context.Context) (*int64, error)
}

// repository implements Repository on top of database/sql. Its queries use
// MySQL syntax (REPLACE INTO, `?` placeholders) against a table named topic.
type repository struct {
	// sqlClient is the shared connection pool all queries run on.
	sqlClient *sql.DB
	// logger receives diagnostics such as row-scan and close errors.
	logger    log.Logger
}

// NewRepository builds the SQL-backed topic repository.
// sqlClient is used for every query; logger receives diagnostic messages.
func NewRepository(sqlClient *sql.DB, logger log.Logger) *repository {
	repo := new(repository)
	repo.sqlClient = sqlClient
	repo.logger = logger
	return repo
}

// TODO: the original author left an unexplained "todo" marker here.
// topicRow is the storage representation of a Topic: one row of the topic
// table. Nullable columns are held in sql.Null* wrappers so they can be
// scanned from and written to the database directly.
type topicRow struct {
	Id              string
	Name            string
	Enabled         bool
	IsDefault       bool
	// Keywords holds Topic.Keywords encoded as JSON (see newTopicRow);
	// NULL when the topic has no keywords.
	Keywords        sql.NullString
	BeginTimestamp  sql.NullInt64
	EndTimestamp    sql.NullInt64
	CreateTimestamp int64
	UpdateTimestamp int64
	ParentId        sql.NullString
	Comment         sql.NullString
	// Version is the value aggregated by repository.GetVersion (MAX(version)).
	Version         int64
}

// newTopicRow converts a domain Topic into its storage row.
//
// Optional Topic fields that are nil, and an empty Keywords slice, become
// NULL columns. Non-empty keywords are stored as a JSON array string.
func newTopicRow(topic *Topic) *topicRow {
	row := &topicRow{
		Id:              topic.Id,
		Name:            topic.Name,
		Enabled:         topic.Enabled,
		IsDefault:       topic.IsDefault,
		CreateTimestamp: topic.CreateTimestamp,
		UpdateTimestamp: topic.UpdateTimestamp,
		Version:         topic.Version,
	}
	if len(topic.Keywords) != 0 {
		// The marshal error is discarded: on failure encoded is nil and an
		// empty (non-NULL) string is stored. NOTE(review): confirm the
		// Keywords type can never fail to marshal.
		encoded, _ := json.Marshal(topic.Keywords)
		row.SetKeywords(string(encoded))
	}
	if begin := topic.BeginTimestamp; begin != nil {
		row.SetBeginTimestamp(*begin)
	}
	if end := topic.EndTimestamp; end != nil {
		row.SetEndTimestamp(*end)
	}
	if parent := topic.ParentId; parent != nil {
		row.SetParentId(*parent)
	}
	if note := topic.Comment; note != nil {
		row.SetComment(*note)
	}
	return row
}

// Topic converts the stored row back into a domain Topic.
//
// NULL columns leave the corresponding optional Topic field unset. The
// keywords column is decoded from JSON; a decode error is deliberately
// ignored, so a malformed value yields whatever json.Unmarshal produced
// (nil or partially populated) rather than an error.
func (tr *topicRow) Topic() *Topic {
	t := new(Topic)
	t.Id = tr.Id
	t.Name = tr.Name
	t.Enabled = tr.Enabled
	t.IsDefault = tr.IsDefault
	t.CreateTimestamp = tr.CreateTimestamp
	t.UpdateTimestamp = tr.UpdateTimestamp
	t.Version = tr.Version

	if kw := tr.Keywords; kw.Valid {
		var decoded []Keywords
		// Decode error intentionally dropped; see the function comment.
		_ = json.Unmarshal([]byte(kw.String), &decoded)
		t.Keywords = decoded
	}
	if begin := tr.BeginTimestamp; begin.Valid {
		t.SetBeginTimestamp(begin.Int64)
	}
	if end := tr.EndTimestamp; end.Valid {
		t.SetEndTimestamp(end.Int64)
	}
	if parent := tr.ParentId; parent.Valid {
		t.SetParentId(parent.String)
	}
	if note := tr.Comment; note.Valid {
		t.SetComment(note.String)
	}
	return t
}

// SetKeywords stores keywords (the JSON-encoded keyword list produced by
// newTopicRow) and marks the column as non-NULL.
func (tr *topicRow) SetKeywords(keywords string) {
	tr.Keywords.String = keywords
	tr.Keywords.Valid = true
}

// SetBeginTimestamp stores beginTime and marks the column as non-NULL.
func (tr *topicRow) SetBeginTimestamp(beginTime int64) {
	tr.BeginTimestamp.Int64 = beginTime
	tr.BeginTimestamp.Valid = true
}

// SetEndTimestamp stores endTime and marks the column as non-NULL.
func (tr *topicRow) SetEndTimestamp(endTime int64) {
	tr.EndTimestamp.Int64 = endTime
	tr.EndTimestamp.Valid = true
}

// SetParentId stores parentId and marks the column as non-NULL.
func (tr *topicRow) SetParentId(parentId string) {
	tr.ParentId.String = parentId
	tr.ParentId.Valid = true
}

// SetComment stores comment and marks the column as non-NULL.
func (tr *topicRow) SetComment(comment string) {
	tr.Comment.String = comment
	tr.Comment.Valid = true
}

// Save writes topic to the topic table using MySQL REPLACE INTO: a row that
// conflicts on a primary/unique key (presumably id — confirm against the
// schema) is deleted and re-inserted with the new values.
//
// ctx bounds the statement, so caller cancellation and deadlines apply.
func (r *repository) Save(ctx context.Context, topic *Topic) error {
	tr := newTopicRow(topic)
	// Placeholders follow the column order: id, name, enabled, isDefault,
	// keywords, beginTimestamp, endTimestamp, createTimestamp,
	// updateTimestamp, parentId, comment, version.
	_, err := r.sqlClient.ExecContext(
		ctx,
		`REPLACE INTO
         topic(
             id,
             name,
             enabled,
             isDefault,
             keywords,
             beginTimestamp,
             endTimestamp,
             createTimestamp,
             updateTimestamp,
             parentId,
             comment,
             version
         )
         VALUES(
             ?,
             ?,
             ?,
             ?,
             ?,
             ?,
             ?,
             ?,
             ?,
             ?,
			 ?,
			 ?
		 );`,
		tr.Id,
		tr.Name,
		tr.Enabled,
		tr.IsDefault,
		tr.Keywords,
		tr.BeginTimestamp,
		tr.EndTimestamp,
		tr.CreateTimestamp,
		tr.UpdateTimestamp,
		tr.ParentId,
		tr.Comment,
		tr.Version,
	)
	return err
}

// Delete removes the topic with the given id. Deleting an id that does not
// exist is not an error: the affected-row count is not inspected.
//
// ctx bounds the statement, so caller cancellation and deadlines apply.
func (r *repository) Delete(ctx context.Context, id string) error {
	// Logging is best-effort; a logger failure must not block the delete.
	_ = r.logger.Log("delete topic:", id)
	_, err := r.sqlClient.ExecContext(ctx, "DELETE FROM topic WHERE id = ?;", id)
	return err
}

// Search returns all topics matching the filters enabled in opts.
//
// Each true field in opts adds one `column = ?` condition, in the order
// enabled, isDefault. args must supply exactly one value per added
// condition, in that order; a mismatch surfaces as a driver error.
// Rows that fail to scan are logged and skipped. An error that occurs while
// iterating is returned instead of a silently truncated result.
// ctx bounds the query.
func (r *repository) Search(ctx context.Context, opts SearchOpts, args ...interface{}) ([]*Topic, error) {
	var baseQuery = `SELECT
		id,
		name,
		enabled,
		isDefault,
		keywords,
		beginTimestamp,
		endTimestamp,
		createTimestamp,
		updateTimestamp,
		parentId,
		comment,
		version
	FROM topic`
	var conditions []string
	if opts.Enabled {
		conditions = append(conditions, `enabled = ?`)
	}
	if opts.IsDefault {
		conditions = append(conditions, `isDefault = ?`)
	}
	if len(conditions) > 0 {
		baseQuery += " WHERE " + strings.Join(conditions, " AND ")
	}
	baseQuery += `;`

	rows, err := r.sqlClient.QueryContext(ctx, baseQuery, args...)
	if err != nil {
		return nil, err
	}
	defer func() {
		if cerr := rows.Close(); cerr != nil {
			_ = r.logger.Log("mysql error", cerr.Error())
		}
	}()

	var topics []*Topic
	for rows.Next() {
		var tr topicRow
		if err := rows.Scan(
			&tr.Id,
			&tr.Name,
			&tr.Enabled,
			&tr.IsDefault,
			&tr.Keywords,
			&tr.BeginTimestamp,
			&tr.EndTimestamp,
			&tr.CreateTimestamp,
			&tr.UpdateTimestamp,
			&tr.ParentId,
			&tr.Comment,
			&tr.Version,
		); err != nil {
			// Best-effort: a single bad row should not fail the whole search.
			_ = r.logger.Log("error", err.Error())
			continue
		}
		topics = append(topics, tr.Topic())
	}
	// rows.Next returns false both at the end of the result set and on
	// failure (e.g. a dropped connection); only rows.Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return topics, nil
}

var ErrTopicNotFound = errors.New("专题不存在")

// Get loads the topic with the given id.
//
// It returns ErrTopicNotFound when no row matches, and any other database
// error unchanged. ctx bounds the query.
func (r *repository) Get(ctx context.Context, id string) (*Topic, error) {
	var tr topicRow
	err := r.sqlClient.QueryRowContext(
		ctx,
		`SELECT
		     id,
		     name,
		     enabled,
		     isDefault,
		     keywords,
		     beginTimestamp,
		     endTimestamp,
		     createTimestamp,
		     updateTimestamp,
		     parentId,
		     comment,
		     version
		 FROM topic
		 WHERE id = ?;`,
		id,
	).Scan(
		&tr.Id,
		&tr.Name,
		&tr.Enabled,
		&tr.IsDefault,
		&tr.Keywords,
		&tr.BeginTimestamp,
		&tr.EndTimestamp,
		&tr.CreateTimestamp,
		&tr.UpdateTimestamp,
		&tr.ParentId,
		&tr.Comment,
		&tr.Version,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrTopicNotFound
	case err != nil:
		return nil, err
	}
	return tr.Topic(), nil
}

// GetVersion returns the largest version value in the topic table.
//
// MAX over an empty table is NULL, so it returns (nil, nil) when there are
// no topics. Query errors are logged and returned. ctx bounds the query.
func (r *repository) GetVersion(ctx context.Context) (*int64, error) {
	var version sql.NullInt64
	err := r.sqlClient.QueryRowContext(ctx, `
		SELECT 
			MAX(version)
		FROM topic;
	`).Scan(&version)
	if err != nil {
		_ = r.logger.Log("mysql error", err.Error())
		return nil, err
	}
	if !version.Valid {
		return nil, nil
	}
	latest := version.Int64
	return &latest, nil
}

// upgradeRepository implements UpgradeRepository on top of a Redis client.
type upgradeRepository struct {
	// redisClient is used for every version read and upgrade transaction.
	redisClient *redis.Client
}

// UpgradeRepository manages the versioned topic data cached in Redis.
type UpgradeRepository interface {
	// Upgrade updates the version number stored in Redis to the latest
	// version and updates the stored data accordingly.
	// Parameters: key is the Redis key holding the current data version;
	// data is the latest keyword-topic data that Redis should hold;
	// topicVersion is the latest version number (as stored in MySQL).
	Upgrade(ctx context.Context, key, data string, topicVersion int64) error
	// GetVersion reads the version number stored under key in Redis
	// (presumably the latest_version key).
	GetVersion(ctx context.Context, key string) (int64, error)
}

// NewUpgradeRepository builds the Redis-backed UpgradeRepository that runs
// all of its commands through rClient.
func NewUpgradeRepository(rClient *redis.Client) *upgradeRepository {
	repo := new(upgradeRepository)
	repo.redisClient = rClient
	return repo
}

// Upgrade points key at topicVersion and stores data under a key named after
// topicVersion (its decimal string), in a single MULTI/EXEC transaction.
//
// When key previously held a different version, that old version's data key
// is given a one-hour TTL (presumably so readers still on the old version
// can finish) instead of being deleted immediately. Re-upgrading to the same
// version does not expire anything, because the old and new data keys are
// the same key.
//
// The read-modify-write uses WATCH on key (optimistic locking) and is
// retried up to 5 times on conflict. If every attempt conflicts,
// redis.TxFailedErr is returned rather than reporting success.
func (repo *upgradeRepository) Upgrade(ctx context.Context, key, data string, topicVersion int64) error {
	const (
		maxRetries    = 5
		oldVersionTTL = time.Hour
	)
	newDataKey := strconv.FormatInt(topicVersion, 10)

	txf := func(tx *redis.Tx) error {
		oldVersion, err := tx.Get(ctx, key).Int64()
		hasOld := true
		switch {
		case errors.Is(err, redis.Nil):
			// First upgrade: there is no previous version to expire.
			hasOld = false
		case err != nil:
			return err
		}

		_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
			pipe.Set(ctx, key, topicVersion, 0)
			pipe.Set(ctx, newDataKey, data, 0)
			// Expiring the old data key when it equals the new one would make
			// the data just written disappear after the TTL.
			if hasOld && oldVersion != topicVersion {
				pipe.Expire(ctx, strconv.FormatInt(oldVersion, 10), oldVersionTTL)
			}
			return nil
		})
		return err
	}

	for attempt := 0; attempt < maxRetries; attempt++ {
		err := repo.redisClient.Watch(ctx, txf, key)
		if errors.Is(err, redis.TxFailedErr) {
			// key changed between WATCH and EXEC; try again.
			continue
		}
		return err // nil on success, otherwise a non-retryable error
	}
	return redis.TxFailedErr
}

// GetVersion reads the integer stored under key. If the key is missing, the
// error from the GET command (redis.Nil in go-redis) is returned unchanged.
// A value that does not parse as an integer yields a parse error.
func (repo *upgradeRepository) GetVersion(ctx context.Context, key string) (int64, error) {
	cmd := repo.redisClient.Get(ctx, key)
	return cmd.Int64()
}
